/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

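/*!
 * \file pp_matmul_f16_nz_kernel.h
 * \brief Declares PpMatmulF16NZ, a tiled matmul kernel class for NZ-format operands on
 *        ArchType::ASCEND_V200; its definitions are compiled only when __DAV_M200__ is defined.
 */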
#ifndef __PP_MAT_MUL_F16_NZ_KERNEL_H__
#define __PP_MAT_MUL_F16_NZ_KERNEL_H__

#ifdef __CCE_KT_TEST__
#include "stub_def.h"
#include "stub_fun.h"
#else
#define __aicore__ [aicore]
#endif

#include "utils/kernel/common.h"
#include "utils/kernel/mem.h"
#include "utils/kernel/iterator.h"
#include "utils/kernel/mma.h"
#include "utils/kernel/utils.h"
#include "pp_mat_mul_common.h"

namespace PpMatMulNS {

#ifdef __DAV_M200__
/**
 * Plain record of seven 32-bit extents; every field defaults to zero.
 *
 * NOTE(review): nothing in the visible part of this file reads this struct. The field
 * names match the like-named tiling values consumed by PpMatmulF16NZ::SetArgs
 * (batch / m / k / n and the per-tile m0 / k0 / n0), so the meanings below are
 * presumed from that -- confirm against the host-side tiling code.
 */
struct OpShapeNz {
    uint32_t batchSize = 0U; // presumably number of batched matmuls
    uint32_t m = 0U;         // presumably M extent of the problem
    uint32_t k = 0U;         // presumably K extent of the problem
    uint32_t n = 0U;         // presumably N extent of the problem
    uint32_t m0 = 0U;        // presumably M extent of one base tile
    uint32_t k0 = 0U;        // presumably K extent of one base tile
    uint32_t n0 = 0U;        // presumably N extent of one base tile
};

/**
 * Tiling record: an OpShapeNz followed by nine 32-bit scalars. The loop counters,
 * swizzlCount and blockDim default to 1; tilingKey, swizzlDirect and splitk default to 0.
 *
 * NOTE(review): nothing in the visible part of this file reads this struct
 * (PpMatmulF16NZ::SetArgs takes a PpMatmulTilingData instead). The per-field notes
 * below are presumed from the like-named values used by PpMatmulF16NZ -- confirm.
 */
struct PpTilingDataNz {
    OpShapeNz opShape;          // problem / base-tile extents
    uint32_t mLoop = 1U;        // presumably number of tiles along M
    uint32_t kLoop = 1U;        // presumably number of blocks along K
    uint32_t nLoop = 1U;        // presumably number of tiles along N
    uint32_t coreLoop = 1U;     // presumably total number of output tiles over all batches
    uint32_t swizzlCount = 1U;  // presumably width of one swizzle band, in tiles
    uint32_t tilingKey = 0U;    // meaning not visible here
    uint32_t blockDim = 1U;     // presumably number of cores launched
    uint32_t swizzlDirect = 0U; // presumably selects the swizzle direction (0 or 1)
    uint32_t splitk = 0U;       // meaning not visible here
};

template <uint32_t SwizzleDirect, bool TA, bool TB, typename InDtype = half, typename OutDtype = half, typename AccumDtype = float, bool EnablePreload = false>
class PpMatmulF16NZ {
    using OnChipBuffer = AsdopsBuffer<ArchType::ASCEND_V200>;

public:
    __aicore__ explicit PpMatmulF16NZ(){};

    __aicore__ FORCE_INLINE void SetArgs(GM_ADDR a, GM_ADDR b, GM_ADDR c, const PpMatmulTilingData* tilingData) {
        gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(a));
        gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(b));
        gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(c));
        batch_size = tilingData->batch;
        m = RoundUp<CONST_16>(tilingData->m);
        k = RoundUp<CONST_16>(tilingData->k);
        n = RoundUp<CONST_16>(tilingData->n);
        m0 = tilingData->m0;
        k0 = tilingData->k0;
        n0 = tilingData->n0;
        m_loop = tilingData->mLoop;
        k_loop = tilingData->kLoop;
        n_loop = tilingData->nLoop;
        core_loop = tilingData->coreLoop;
        swizzl_count = tilingData->swizzlCount;
        core_num = AscendC::GetBlockNum();
        core_idx = AscendC::GetBlockIdx();
        ping_flag = 1;

        OnChipBuffer buf;
        l1_base_a = buf.template GetBuffer<BufferType::ASCEND_CB, InDtype>(0);
        l1_base_b = buf.template GetBuffer<BufferType::ASCEND_CB, InDtype>(RoundUp<CONST_256>(m0 * k0 * sizeof(InDtype)));
        l0a_base = buf.template GetBuffer<BufferType::ASCEND_L0A, InDtype>(0);
        l0b_base = buf.template GetBuffer<BufferType::ASCEND_L0B, InDtype>(0);
        l0c_buf = buf.template GetBuffer<BufferType::ASCEND_L0C, AccumDtype>(0);
        ub_c = buf.template GetBuffer<BufferType::ASCEND_UB, OutDtype>(0);
    }

    __aicore__ FORCE_INLINE void GetTileIdx(const uint32_t loop_idx, uint64_t &m_idx, uint64_t &n_idx) {
        uint32_t in_batch_idx = loop_idx % (m_loop * n_loop);
        if constexpr (SwizzleDirect == 0) { // Zn
            uint32_t tile_block_loop = (m_loop + swizzl_count - 1) / swizzl_count;
            uint32_t tile_block_idx = in_batch_idx / (swizzl_count * n_loop);
            uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * n_loop);

            uint32_t n_row = swizzl_count;
            if (tile_block_idx == tile_block_loop - 1) {
                n_row = m_loop - swizzl_count * tile_block_idx;
            }
            m_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_row;
            n_idx = in_tile_block_idx / n_row;
            if (tile_block_idx % 2 != 0) {
                n_idx = n_loop - n_idx - 1;
            }
        } else if constexpr (SwizzleDirect == 1) { // Nz
            uint32_t tile_block_loop = (n_loop + swizzl_count - 1) / swizzl_count;
            uint32_t tile_block_idx = in_batch_idx / (swizzl_count * m_loop);
            uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * m_loop);

            uint32_t n_col = swizzl_count;
            if (tile_block_idx == tile_block_loop - 1) {
                n_col = n_loop - swizzl_count * tile_block_idx;
            }
            m_idx = in_tile_block_idx / n_col;
            n_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_col;
            if (tile_block_idx % 2 != 0) {
                m_idx = m_loop - m_idx - 1;
            }
        }
    }

    __aicore__ FORCE_INLINE void Run() {

        using LocalTensor = AscendC::LocalTensor<InDtype>;
        using CopyGmToCbuf = gm_to_l1<ArchType::ASCEND_V200, InDtype, DataFormat::NZ, DataFormat::NZ>;
        using LoadCbufToCa = l1_to_l0_a<ArchType::ASCEND_V200, InDtype, TA, DataFormat::ZN, DataFormat::ZZ>;
        using LoadCbufToCb = l1_to_l0_b<ArchType::ASCEND_V200, InDtype, TB, DataFormat::ZN, DataFormat::NZ>;
        using Mmad = mmad<ArchType::ASCEND_V200, InDtype, InDtype, AccumDtype, false>;
        using CopyMatrixCcToUbuf = l0c_to_ub<ArchType::ASCEND_V200, AccumDtype, OutDtype, true>;
        using CopyUbufToGm = ub_to_gm<ArchType::ASCEND_V200, InDtype, DataFormat::NZ, DataFormat::NZ>;
        SET_FLAG(MTE1, MTE2, EVENT_ID0);
        SET_FLAG(MTE1, MTE2, EVENT_ID1);
        SET_FLAG(MTE1, MTE2, EVENT_ID2);
        SET_FLAG(MTE1, MTE2, EVENT_ID3);
        SET_FLAG(M, MTE1, EVENT_ID0);
        SET_FLAG(M, MTE1, EVENT_ID1);
        SET_FLAG(V, M, EVENT_ID0);
        SET_FLAG(MTE3, V, EVENT_ID0);
        for (uint32_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) {
            uint64_t batch_idx = loop_idx / (m_loop * n_loop);
            uint64_t m_idx = 0, n_idx = 0;
            GetTileIdx(loop_idx, m_idx, n_idx);
            uint64_t offset_a = 0, offset_b = 0, offset_a_next = 0, offset_b_next = 0;
            uint64_t offset_c = batch_idx * m * n + n_idx * n0 * m + m_idx * m0 * CONST_16;
            uint32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0;
            uint32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
            uint32_t m_round = RoundUp<CONST_16>(m_actual);
            uint32_t n_round = RoundUp<CONST_16>(n_actual);
            uint32_t mn_max = m_round > n_round ? m_round : n_round;
            uint32_t k_part_len = L0_PINGPONG_BUFFER_SIZE / mn_max / CONST_16 * CONST_16;

            if constexpr (EnablePreload) {
                LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
                LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
                auto event_id_l1_a = ping_flag ? EVENT_ID0 : EVENT_ID1;
                auto event_id_l1_b = ping_flag ? EVENT_ID2 : EVENT_ID3;
                if (loop_idx == core_idx) {
                    uint32_t k_actual = ((k_loop == 1) ? k : k0);
                    uint32_t k_round = RoundUp<CONST_16>(k_actual);
                    WAIT_FLAG(MTE1, MTE2, event_id_l1_a);
                    if constexpr (TA) {
                        offset_a = batch_idx * m * k + m_idx * m0 * k;
                        CopyGmToCbuf(l1_buf_a,       // src
                                    gm_a[offset_a], // dst
                                    k_actual,       // nTileActual
                                    k_round,        // nTileCeil
                                    k,              // nVal
                                    m_actual,       // dTileActual
                                    m_round,        // dTileCeil
                                    m);             // dVal
                    } else {
                        offset_a = batch_idx * m * k + m_idx * m0 * CONST_16;
                        CopyGmToCbuf(l1_buf_a,       // src
                                    gm_a[offset_a], // dst
                                    m_actual,       // nTileActual
                                    m_round,        // nTileCeil
                                    m,              // nVal
                                    k_actual,       // dTileActual
                                    k_round,        // dTileCeil
                                    k);             // dVal
                    }
                    SET_FLAG(MTE2, MTE1, event_id_l1_a);

                    WAIT_FLAG(MTE1, MTE2, event_id_l1_b);
                    if constexpr (TB) {
                        offset_b = batch_idx * n * k + n_idx * n0 * CONST_16;
                        CopyGmToCbuf(l1_buf_b,       // src
                                    gm_b[offset_b], // dst
                                    n_actual,       // nTileActual
                                    n_round,        // nTileCeil
                                    n,              // nVal
                                    k_actual,       // dTileActual
                                    k_round,        // dTileCeil
                                    k);             // dVal
                    } else {
                        offset_b = batch_idx * n * k + n_idx * n0 * k;
                        CopyGmToCbuf(l1_buf_b,       // src
                                    gm_b[offset_b], // dst
                                    k_actual,       // nTileActual
                                    k_round,        // nTileCeil
                                    k,              // nVal
                                    n_actual,       // dTileActual
                                    n_round,        // dTileCeil
                                    n);             // dVal
                    }
                    SET_FLAG(MTE2, MTE1, event_id_l1_b);
                }
            }

            for (uint64_t k_idx = 0; k_idx < k_loop; ++k_idx) {
                uint32_t k_actual = (k_idx == (k_loop - 1)) ? (k - k_idx * k0) : k0;
                uint32_t k_round = RoundUp<CONST_16>(k_actual);
                LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
                LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
                auto event_id_l1_a = ping_flag ? EVENT_ID0 : EVENT_ID1;
                auto event_id_l1_b = ping_flag ? EVENT_ID2 : EVENT_ID3;
                if constexpr (EnablePreload) {
                    if (k_idx + 1 < k_loop) {
                        uint64_t k_idx_next = k_idx + 1;
                        uint32_t k_actual_next = (k_idx_next == k_loop - 1) ? (k - k_idx_next * k0) : k0;
                        uint32_t k_round_next = RoundUp<CONST_16>(k_actual_next);

                        LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
                        LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
                        event_t event_id_l1_a_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;
                        event_t event_id_l1_b_next = (1 - ping_flag) ? EVENT_ID2 : EVENT_ID3;
                        WAIT_FLAG(MTE1, MTE2, event_id_l1_a_next);
                        if constexpr (TA) {
                            offset_a_next = batch_idx * m * k + m_idx * m0 * k + k_idx_next * k0 * CONST_16;
                            CopyGmToCbuf(l1_buf_a_next,       // src
                                        gm_a[offset_a_next], // dst
                                        k_actual_next,       // nTileActual
                                        k_round_next,        // nTileCeil
                                        k,                   // nVal
                                        m_actual,            // dTileActual
                                        m_round,             // dTileCeil
                                        m);                  // dVal
                        } else {
                            offset_a_next = batch_idx * m * k + k_idx_next * k0 * m + m_idx * m0 * CONST_16;
                            CopyGmToCbuf(l1_buf_a_next,       // src
                                        gm_a[offset_a_next], // dst
                                        m_actual,            // nTileActual
                                        m_round,             // nTileCeil
                                        m,                   // nVal
                                        k_actual_next,       // dTileActual
                                        k_round_next,        // dTileCeil
                                        k);                  // dVal
                        }
                        SET_FLAG(MTE2, MTE1, event_id_l1_a_next);

                        WAIT_FLAG(MTE1, MTE2, event_id_l1_b_next);
                        if constexpr (TB) {
                            offset_b_next = batch_idx * n * k + k_idx_next * k0 * n + n_idx * n0 * CONST_16;
                            CopyGmToCbuf(l1_buf_b_next,       // src
                                        gm_b[offset_b_next], // dst
                                        n_actual,            // nTileActual
                                        n_round,             // nTileCeil
                                        n,                   // nVal
                                        k_actual_next,       // dTileActual
                                        k_round_next,        // dTileCeil
                                        k);                  // dVal
                        } else {
                            offset_b_next = batch_idx * n * k + n_idx * n0 * k + k_idx_next * k0 * CONST_16;
                            CopyGmToCbuf(l1_buf_b_next,       // src
                                        gm_b[offset_b_next], // dst
                                        k_actual_next,       // nTileActual
                                        k_round_next,        // nTileCeil
                                        k,                   // nVal
                                        n_actual,            // dTileActual
                                        n_round,             // dTileCeil
                                        n);                  // dVal
                        }
                        SET_FLAG(MTE2, MTE1, event_id_l1_b_next);
                    }

                    if (k_idx + 1 == k_loop && loop_idx + core_num < core_loop) {
                        uint64_t batch_idx_next = (loop_idx + core_num) / (m_loop * n_loop);
                        uint64_t m_idx_next = 0, n_idx_next = 0;
                        GetTileIdx(loop_idx + core_num, m_idx_next, n_idx_next);
                        uint32_t m_actual_next = (m_idx_next == (m_loop - 1)) ? (m - m_idx_next * m0) : m0;
                        uint32_t k_actual_next = (k_loop == 1) ? k : k0;
                        uint32_t n_actual_next = (n_idx_next == (n_loop - 1)) ? (n - n_idx_next * n0) : n0;
                        uint32_t m_round_next = RoundUp<CONST_16>(m_actual_next);
                        uint32_t n_round_next = RoundUp<CONST_16>(n_actual_next);
                        uint32_t k_round_next = RoundUp<CONST_16>(k_actual_next);

                        LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
                        LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
                        event_t event_id_l1_a_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;
                        event_t event_id_l1_b_next = (1 - ping_flag) ? EVENT_ID2 : EVENT_ID3;
                        WAIT_FLAG(MTE1, MTE2, event_id_l1_a_next);
                        if constexpr (TA) {
                            offset_a_next = batch_idx_next * m * k + m_idx_next * m0 * k;
                            CopyGmToCbuf(l1_buf_a_next,       // src
                                        gm_a[offset_a_next], // dst
                                        k_actual_next,       // nTileActual
                                        k_round_next,        // nTileCeil
                                        k,                   // nVal
                                        m_actual_next,       // dTileActual
                                        m_round_next,        // dTileCeil
                                        m);                  // dVal
                        } else {
                            offset_a_next = batch_idx_next * m * k + m_idx_next * m0 * CONST_16;
                            CopyGmToCbuf(l1_buf_a_next,       // src
                                        gm_a[offset_a_next], // dst
                                        m_actual_next,       // nTileActual
                                        m_round_next,        // nTileCeil
                                        m,                   // nVal
                                        k_actual_next,       // dTileActual
                                        k_round_next,        // dTileCeil
                                        k);                  // dVal
                        }
                        SET_FLAG(MTE2, MTE1, event_id_l1_a_next);

                        WAIT_FLAG(MTE1, MTE2, event_id_l1_b_next);
                        if constexpr (TB) {
                            offset_b_next = batch_idx_next * n * k + n_idx_next * n0 * CONST_16;
                            CopyGmToCbuf(l1_buf_b_next,       // src
                                        gm_b[offset_b_next], // dst
                                        n_actual_next,       // nTileActual
                                        n_round_next,        // nTileCeil
                                        n,                   // nVal
                                        k_actual_next,       // dTileActual
                                        k_round_next,        // dTileCeil
                                        k);                  // dVal
                        } else {
                            offset_b_next = batch_idx_next * n * k + n_idx_next * n0 * k;
                            CopyGmToCbuf(l1_buf_b_next,       // src
                                        gm_b[offset_b_next], // dst
                                        k_actual_next,       // nTileActual
                                        k_round_next,        // nTileCeil
                                        k,                   // nVal
                                        n_actual_next,       // dTileActual
                                        n_round_next,        // dTileCeil
                                        n);                  // dVal
                        }
                        SET_FLAG(MTE2, MTE1, event_id_l1_b_next);
                    }
                } else {
                    WAIT_FLAG(MTE1, MTE2, event_id_l1_a);
                    if constexpr (TA) {
                        offset_a = batch_idx * m * k + m_idx * m0 * k + k_idx * k0 * CONST_16;
                        CopyGmToCbuf(l1_buf_a,       // src
                                    gm_a[offset_a], // dst
                                    k_actual,       // nTileActual
                                    k_round,        // nTileCeil
                                    k,              // nVal
                                    m_actual,       // dTileActual
                                    m_round,        // dTileCeil
                                    m);             // dVal
                    } else {
                        offset_a = batch_idx * m * k + k_idx * k0 * m + m_idx * m0 * CONST_16;
                        CopyGmToCbuf(l1_buf_a,       // src
                                    gm_a[offset_a], // dst
                                    m_actual,       // nTileActual
                                    m_round,        // nTileCeil
                                    m,              // nVal
                                    k_actual,       // dTileActual
                                    k_round,        // dTileCeil
                                    k);             // dVal
                    }
                    SET_FLAG(MTE2, MTE1, event_id_l1_a);

                    WAIT_FLAG(MTE1, MTE2, event_id_l1_b);
                    if constexpr (TB) {
                        offset_b = batch_idx * n * k + k_idx * k0 * n + n_idx * n0 * CONST_16;
                        CopyGmToCbuf(l1_buf_b,       // src
                                    gm_b[offset_b], // dst
                                    n_actual,       // nTileActual
                                    n_round,        // nTileCeil
                                    n,              // nVal
                                    k_actual,       // dTileActual
                                    k_round,        // dTileCeil
                                    k);             // dVal
                    } else {
                        offset_b = batch_idx * n * k + n_idx * n0 * k + k_idx * k0 * CONST_16;
                        CopyGmToCbuf(l1_buf_b,       // src
                                    gm_b[offset_b], // dst
                                    k_actual,       // nTileActual
                                    k_round,        // nTileCeil
                                    k,              // nVal
                                    n_actual,       // dTileActual
                                    n_round,        // dTileCeil
                                    n);             // dVal
                    }
                    SET_FLAG(MTE2, MTE1, event_id_l1_b);
                }
                uint32_t k_part_loop = (k_actual + k_part_len - 1) / k_part_len;
                for (uint32_t k_part_idx = 0; k_part_idx < k_part_loop; ++k_part_idx) {
                    uint32_t k0_round = (k_part_idx < k_part_loop - 1) ? k_part_len : k_round - k_part_idx * k_part_len;
                    uint32_t k0_actual = (k_part_idx < k_part_loop - 1) ? k_part_len : k_actual - k_part_idx * k_part_len;

                    auto l0_event_id = (1 - k_part_idx & 0x1) ? EVENT_ID0 : EVENT_ID1;
                    LocalTensor l0a_buf = l0a_base[(k_part_idx & 0x1) * L0_PINGPONG_BUFFER_SIZE];
                    LocalTensor l0b_buf = l0b_base[(k_part_idx & 0x1) * L0_PINGPONG_BUFFER_SIZE];

                    // L1 -> L0A
                    if (k_part_idx == 0) {
                        WAIT_FLAG(MTE2, MTE1, event_id_l1_a);
                    }
                    WAIT_FLAG(M, MTE1, l0_event_id);
                    if constexpr (TA) {
                        LoadCbufToCa(l0a_buf,                                      // l0Tensor
                                    l1_buf_a[k_part_idx * k_part_len * CONST_16], // l1Tensor
                                    m_round,                                      // mTileCeil
                                    k0_round,                                     // kPartCeil
                                    k_round / CONST_16,                           // mSrcStride
                                    1,                                            // kSrcStride
                                    k0_round / CONST_16,                          // mDstStride
                                    1);                                           // kDstStride
                    } else {
                        LoadCbufToCa(l0a_buf,                                     // l0Tensor
                                    l1_buf_a[k_part_idx * k_part_len * m_round], // l1Tensor
                                    m_round,                                     // mTileCeil
                                    k0_round,                                    // kPartCeil
                                    1,                                           // mSrcStride
                                    m_round / CONST_16,                          // kSrcStride
                                    k0_round / CONST_16,                         // mDstStride
                                    1);                                          // kDstStride
                    }
                    if (k_part_idx == k_part_loop - 1) {
                        SET_FLAG(MTE1, MTE2, event_id_l1_a);
                    }

                    // L1 -> L0B
                    if (k_part_idx == 0) {
                        WAIT_FLAG(MTE2, MTE1, event_id_l1_b);
                    }
                    if constexpr (TB) {
                        LoadCbufToCb(l0b_buf,                                     // l0Tensor
                                    l1_buf_b[k_part_idx * k_part_len * n_round], // l1Tensor
                                    n_round,                                     // nTileCeil
                                    k0_round,                                    // kPartCeil
                                    1,                                           // nSrcStride
                                    n_round / CONST_16,                          // kSrcStride
                                    1,                                           // nDstStride
                                    n_round / CONST_16);                         // kDstStride
                    } else {
                        LoadCbufToCb(l0b_buf,                                      // l0Tensor
                                    l1_buf_b[k_part_idx * k_part_len * CONST_16], // l1Tensor
                                    n_round,                                      // nTileCeil
                                    k0_round,                                     // kPartCeil
                                    k_round / CONST_16,                           // nSrcStride
                                    1,                                            // kSrcStride
                                    1,                                            // nDstStride
                                    n_round / CONST_16);                          // kDstStride
                    }
                    if (k_part_idx == k_part_loop - 1) {
                        SET_FLAG(MTE1, MTE2, event_id_l1_b);
                    }

                    SET_FLAG(MTE1, M, l0_event_id);
                    WAIT_FLAG(MTE1, M, l0_event_id);

                    bool init_c = (k_idx == 0 && k_part_idx == 0);
                    if (init_c) {
                        WAIT_FLAG(V, M, EVENT_ID0);
                    }
                    PIPE_BARRIER(M);
                    Mmad(l0c_buf,   // c
                        l0a_buf,   // a
                        l0b_buf,   // b
                        m_actual,  // mTileActual
                        n_actual,  // nTileActual
                        k0_actual, // kTileActual
                        init_c);   // initC
                    SET_FLAG(M, MTE1, l0_event_id);
                }

                ping_flag = 1 - ping_flag;
            }

            // copy from L0C to gm
            SET_FLAG(M, V, EVENT_ID0);
            WAIT_FLAG(M, V, EVENT_ID0);
            WAIT_FLAG(MTE3, V, EVENT_ID0);
            CopyMatrixCcToUbuf(ub_c,               // dst
                            l0c_buf,            // src
                            n_round / CONST_16, // nBurst
                            m_round / CONST_16, // lenBurst
                            0,                  // srcStride
                            0);                 // dstStride
            SET_FLAG(V, M, EVENT_ID0);

            SET_FLAG(V, MTE3, EVENT_ID0);
            WAIT_FLAG(V, MTE3, EVENT_ID0);

            CopyUbufToGm(gm_c[offset_c], // dst
                        ub_c,           // src
                        m_actual,       // nTileActual
                        m_round,        // nTileCeil
                        m,              // nVal
                        n_actual,       // dTileActual
                        n_round,        // dTileCeil
                        n);             // dVal
            SET_FLAG(MTE3, V, EVENT_ID0);
        }
        WAIT_FLAG(MTE1, MTE2, EVENT_ID0);
        WAIT_FLAG(MTE1, MTE2, EVENT_ID1);
        WAIT_FLAG(MTE1, MTE2, EVENT_ID2);
        WAIT_FLAG(MTE1, MTE2, EVENT_ID3);
        WAIT_FLAG(M, MTE1, EVENT_ID0);
        WAIT_FLAG(M, MTE1, EVENT_ID1);
        WAIT_FLAG(V, M, EVENT_ID0);
        WAIT_FLAG(MTE3, V, EVENT_ID0);

        PIPE_BARRIER(ALL);
    }

private:
    AscendC::GlobalTensor<InDtype> gm_a;
    AscendC::GlobalTensor<InDtype> gm_b;
    AscendC::GlobalTensor<OutDtype> gm_c;
    AscendC::LocalTensor<InDtype> l1_base_a;
    AscendC::LocalTensor<InDtype> l1_base_b;
    AscendC::LocalTensor<InDtype> l0a_base;
    AscendC::LocalTensor<InDtype> l0b_base;
    AscendC::LocalTensor<AccumDtype> l0c_buf;
    AscendC::LocalTensor<OutDtype> ub_c;
    uint32_t core_num{0};
    uint32_t batch_size{0};
    uint32_t m{0};
    uint32_t k{0};
    uint32_t n{0};
    uint32_t m0{0};
    uint32_t k0{0};
    uint32_t n0{0};
    uint32_t m_loop{0};
    uint32_t n_loop{0};
    uint32_t k_loop{0};
    uint32_t core_loop{0};
    uint32_t core_idx{0};
    uint32_t ping_flag{0};
    uint32_t swizzl_count{0};
};
#endif
}
#endif